{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bayes by Backprop with ``gluon`` (NN, classification)\n",
    "\n",
    "After discussing [Bayes by Backprop from scratch](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter18_variational-methods-and-uncertainty/bayes-by-backprop.ipynb) in a previous notebook, we can now look at the corresponding implementation as ``gluon`` components."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We start off with the usual set of imports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:27.533285Z",
     "start_time": "2018-07-24T03:05:26.006406Z"
    }
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import collections\n",
    "import mxnet as mx\n",
    "import numpy as np\n",
    "from mxnet import nd, autograd\n",
    "from matplotlib import pyplot as plt\n",
    "from mxnet import gluon"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For easy tuning and experimentation, we define a dictionary holding the hyper-parameters of our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:27.541116Z",
     "start_time": "2018-07-24T03:05:27.535661Z"
    }
   },
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"num_hidden_layers\": 2,\n",
    "    \"num_hidden_units\": 400, \n",
    "    \"batch_size\": 128,\n",
    "    \"epochs\": 10,\n",
    "    \"learning_rate\": 0.001,\n",
    "    \"num_samples\": 1,\n",
    "    \"pi\": 0.25,\n",
    "    \"sigma_p\": 1.0,\n",
    "    \"sigma_p1\": 0.75,\n",
    "    \"sigma_p2\": 0.01,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Also, we specify the device context for MXNet."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:27.549058Z",
     "start_time": "2018-07-24T03:05:27.544210Z"
    }
   },
   "outputs": [],
   "source": [
    "ctx = mx.cpu()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load MNIST data\n",
    "\n",
    "We will again train and evaluate the algorithm on the MNIST data set and therefore load the data set as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.168603Z",
     "start_time": "2018-07-24T03:05:27.552660Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def transform(data, label):\n",
    "    return data.astype(np.float32)/126.0, label.astype(np.float32)\n",
    "\n",
    "mnist = mx.test_utils.get_mnist()\n",
    "num_inputs = 784\n",
    "num_outputs = 10\n",
    "batch_size = config['batch_size']\n",
    "\n",
    "train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=transform),\n",
    "                                      batch_size, shuffle=True)\n",
    "test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=transform),\n",
    "                                     batch_size, shuffle=False)\n",
    "\n",
    "num_train = sum([batch_size for i in train_data])\n",
    "num_batches = num_train / batch_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to reproduce and compare the results from the paper, we preprocess the pixels by dividing by 126."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model definition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neural net modeling\n",
    "\n",
    "As our model we are using a straightforward MLP and we are wiring up our network just as we are used to in ``gluon``. Note that we are not using any special layers during the definition of our network, as we believe that Bayes by Backprop should be thought of as a training method, rather than a special architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.176075Z",
     "start_time": "2018-07-24T03:05:40.170875Z"
    }
   },
   "outputs": [],
   "source": [
    "num_layers = config['num_hidden_layers']\n",
    "num_hidden = config['num_hidden_units']\n",
    "\n",
    "net = gluon.nn.Sequential()\n",
    "with net.name_scope():\n",
    "    for i in range(num_layers):\n",
    "        net.add(gluon.nn.Dense(num_hidden, activation=\"relu\"))\n",
    "    net.add(gluon.nn.Dense(num_outputs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build objective/loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, we define our loss function as described in [Bayes by Backprop from scratch](https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter18_variational-methods-and-uncertainty/bayes-by-backprop.ipynb). Note that we are bundling all of this functionality as part of a ``gluon.loss.Loss`` subclass, where the loss computation is performed in the ``hybrid_forward`` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.195610Z",
     "start_time": "2018-07-24T03:05:40.179761Z"
    }
   },
   "outputs": [],
   "source": [
    "class BBBLoss(gluon.loss.Loss):\n",
    "    def __init__(self, log_prior=\"gaussian\", log_likelihood=\"softmax_cross_entropy\", \n",
    "                 sigma_p1=1.0, sigma_p2=0.1, pi=0.5, weight=None, batch_axis=0, **kwargs):\n",
    "        super(BBBLoss, self).__init__(weight, batch_axis, **kwargs)\n",
    "        self.log_prior = log_prior\n",
    "        self.log_likelihood = log_likelihood\n",
    "        self.sigma_p1 = sigma_p1\n",
    "        self.sigma_p2 = sigma_p2\n",
    "        self.pi = pi\n",
    "    \n",
    "    def log_softmax_likelihood(self, yhat_linear, y):\n",
    "        return nd.nansum(y * nd.log_softmax(yhat_linear), axis=0, exclude=True)\n",
    "\n",
    "    def log_gaussian(self, x, mu, sigma):\n",
    "        return -0.5 * np.log(2.0 * np.pi) - nd.log(sigma) - (x - mu) ** 2 / (2 * sigma ** 2)\n",
    "\n",
    "    def gaussian_prior(self, x):\n",
    "        sigma_p = nd.array([self.sigma_p1], ctx=ctx)\n",
    "        return nd.sum(self.log_gaussian(x, 0., sigma_p))\n",
    "    \n",
    "    def gaussian(self, x, mu, sigma):\n",
    "        scaling = 1.0 / nd.sqrt(2.0 * np.pi * (sigma ** 2))\n",
    "        bell = nd.exp(- (x - mu) ** 2 / (2.0 * sigma ** 2))\n",
    "\n",
    "        return scaling * bell\n",
    "\n",
    "    def scale_mixture_prior(self, x):\n",
    "        sigma_p1 = nd.array([self.sigma_p1], ctx=ctx)\n",
    "        sigma_p2 = nd.array([self.sigma_p2], ctx=ctx)\n",
    "        pi = self.pi\n",
    "\n",
    "        first_gaussian = pi * self.gaussian(x, 0., sigma_p1)\n",
    "        second_gaussian = (1 - pi) * self.gaussian(x, 0., sigma_p2)\n",
    "\n",
    "        return nd.log(first_gaussian + second_gaussian)\n",
    "        \n",
    "    def hybrid_forward(self, F, output, label, params, mus, sigmas, sample_weight=None):\n",
    "        log_likelihood_sum = nd.sum(self.log_softmax_likelihood(output, label))\n",
    "        prior = None\n",
    "        if self.log_prior == \"gaussian\":\n",
    "            prior = self.gaussian_prior\n",
    "        elif self.log_prior == \"scale_mixture\":\n",
    "            prior = self.scale_mixture_prior\n",
    "        log_prior_sum = sum([nd.sum(prior(param)) for param in params])\n",
    "        log_var_posterior_sum = sum([nd.sum(self.log_gaussian(params[i], mus[i], sigmas[i])) for i in range(len(params))])\n",
    "        return 1.0 / num_batches * (log_var_posterior_sum - log_prior_sum) - log_likelihood_sum\n",
    "    \n",
    "bbb_loss = BBBLoss(log_prior=\"scale_mixture\", sigma_p1=config['sigma_p1'], sigma_p2=config['sigma_p2'])"
   ]
  },
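  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact reference, the value returned by ``hybrid_forward`` for one minibatch can be sketched as follows. The notation is our own shorthand rather than something defined in this notebook: $M$ stands for ``num_batches``, $\\mathcal{D}_i$ for the current minibatch, $\\theta = (\\mathbf{\\mu}, \\mathbf{\\rho})$ for the variational parameters, and $\\mathbf{w}$ for the sampled weights.\n",
    "\n",
    "$$\\mathcal{F}_i(\\theta) = \\frac{1}{M} \\left[ \\log q(\\mathbf{w} \\mid \\theta) - \\log P(\\mathbf{w}) \\right] - \\log P(\\mathcal{D}_i \\mid \\mathbf{w})$$\n",
    "\n",
    "With ``log_prior=\"scale_mixture\"``, the log-prior term sums over the individual weights $w_j$:\n",
    "\n",
    "$$\\log P(\\mathbf{w}) = \\sum_j \\log \\left( \\pi \\, \\mathcal{N}(w_j \\mid 0, \\sigma_{p1}^2) + (1 - \\pi) \\, \\mathcal{N}(w_j \\mid 0, \\sigma_{p2}^2) \\right)$$\n",
    "\n",
    "Here $\\pi$, $\\sigma_{p1}$ and $\\sigma_{p2}$ correspond to the ``pi``, ``sigma_p1`` and ``sigma_p2`` arguments of ``BBBLoss``."
   ]
  },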
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameter initialization\n",
    "\n",
    "First, we need to initialize all the network's parameters, which are only point estimates of the weights at this point. We will soon see, how we can still train the network in a Bayesian fashion, without interfering with the network's architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.216198Z",
     "start_time": "2018-07-24T03:05:40.199505Z"
    }
   },
   "outputs": [],
   "source": [
    "net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we have to forward-propagate a single data set entry once to set up all network parameters (weights and biases) with the desired initliaizer specified above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.326905Z",
     "start_time": "2018-07-24T03:05:40.219151Z"
    }
   },
   "outputs": [],
   "source": [
    "for i, (data, label) in enumerate(train_data):\n",
    "    data = data.as_in_context(ctx).reshape((-1, 784))\n",
    "    net(data)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.355042Z",
     "start_time": "2018-07-24T03:05:40.331261Z"
    }
   },
   "outputs": [],
   "source": [
    "weight_scale = .1\n",
    "rho_offset = -3\n",
    "\n",
    "# initialize variational parameters; mean and variance for each weight\n",
    "mus = []\n",
    "rhos = []\n",
    "\n",
    "shapes = list(map(lambda x: x.shape, net.collect_params().values()))\n",
    "\n",
    "for shape in shapes:\n",
    "    mu = gluon.Parameter('mu', shape=shape, init=mx.init.Normal(weight_scale))\n",
    "    rho = gluon.Parameter('rho',shape=shape, init=mx.init.Constant(rho_offset))\n",
    "    mu.initialize(ctx=ctx)\n",
    "    rho.initialize(ctx=ctx)\n",
    "    mus.append(mu)\n",
    "    rhos.append(rho)\n",
    "\n",
    "variational_params = mus + rhos\n",
    "\n",
    "raw_mus = list(map(lambda x: x.data(ctx), mus))\n",
    "raw_rhos = list(map(lambda x: x.data(ctx), rhos))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimizer\n",
    "\n",
    "Now, we still have to choose the optimizer we wish to use for training. This time, we are using the ``adam`` optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.372971Z",
     "start_time": "2018-07-24T03:05:40.363308Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "trainer = gluon.Trainer(variational_params, 'adam', {'learning_rate': config['learning_rate']})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main training loop\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sampling\n",
    "\n",
    "Recall the 3-step process for the variational parameters:\n",
    "\n",
    "1) Sample $\\mathbf{\\epsilon} \\sim \\mathcal{N}(\\mathbf{0},\\mathbf{I}^{d})$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.388922Z",
     "start_time": "2018-07-24T03:05:40.381960Z"
    }
   },
   "outputs": [],
   "source": [
    "def sample_epsilons(param_shapes):\n",
    "    epsilons = [nd.random_normal(shape=shape, loc=0., scale=1.0, ctx=ctx) for shape in param_shapes]\n",
    "    return epsilons"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2) Transform $\\mathbf{\\rho}$ to a positive vector via the softplus function: $\\mathbf{\\sigma} = \\text{softplus}(\\mathbf{\\rho}) = \\log(1 + \\exp(\\mathbf{\\rho}))$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.400294Z",
     "start_time": "2018-07-24T03:05:40.393401Z"
    }
   },
   "outputs": [],
   "source": [
    "def softplus(x):\n",
    "    return nd.log(1. + nd.exp(x))\n",
    "\n",
    "def transform_rhos(rhos):\n",
    "    return [softplus(rho) for rho in rhos]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3) Compute $\\mathbf{w}$: $\\mathbf{w} = \\mathbf{\\mu} + \\mathbf{\\sigma} \\circ \\mathbf{\\epsilon}$, where the $\\circ$ operator represents the element-wise multiplication. This is the \"reparametrization trick\" for separating the randomness from the parameters of $q$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.417479Z",
     "start_time": "2018-07-24T03:05:40.408643Z"
    }
   },
   "outputs": [],
   "source": [
    "def transform_gaussian_samples(mus, sigmas, epsilons):\n",
    "    samples = []\n",
    "    for j in range(len(mus)):\n",
    "        samples.append(mus[j] + sigmas[j] * epsilons[j])\n",
    "    return samples"
   ]
  },
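  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the three steps concrete, here is a worked example for a single weight. The values $\\mu = 0.1$ and $\\epsilon = 0.5$ are made up for illustration; $\\rho = -3$ is the ``rho_offset`` used for initialization above.\n",
    "\n",
    "$$\\sigma = \\log(1 + e^{-3}) \\approx \\log(1.0498) \\approx 0.0486$$\n",
    "\n",
    "$$w = \\mu + \\sigma \\epsilon \\approx 0.1 + 0.0486 \\cdot 0.5 \\approx 0.1243$$\n",
    "\n",
    "So right after initialization, the posterior of every weight has a standard deviation of roughly $0.05$."
   ]
  },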
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting these three steps together we get:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.428798Z",
     "start_time": "2018-07-24T03:05:40.421868Z"
    }
   },
   "outputs": [],
   "source": [
    "def generate_weight_sample(layer_param_shapes, mus, rhos):\n",
    "    # sample epsilons from standard normal\n",
    "    epsilons = sample_epsilons(layer_param_shapes)\n",
    "\n",
    "    # compute softplus for variance\n",
    "    sigmas = transform_rhos(rhos)\n",
    "\n",
    "    # obtain a sample from q(w|theta) by transforming the epsilons\n",
    "    layer_params = transform_gaussian_samples(mus, sigmas, epsilons)\n",
    "    \n",
    "    return layer_params, sigmas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### Evaluation metric\n",
    "\n",
    "In order to being able to assess our model performance we define a helper function which evaluates our accuracy on an ongoing basis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:05:40.441090Z",
     "start_time": "2018-07-24T03:05:40.432038Z"
    }
   },
   "outputs": [],
   "source": [
    "def evaluate_accuracy(data_iterator, net, layer_params):\n",
    "    numerator = 0.\n",
    "    denominator = 0.\n",
    "    for i, (data, label) in enumerate(data_iterator):\n",
    "        data = data.as_in_context(ctx).reshape((-1, 784))\n",
    "        label = label.as_in_context(ctx)\n",
    "        \n",
    "        for l_param, param in zip(layer_params, net.collect_params().values()):\n",
    "            param._data[0] = l_param\n",
    "        \n",
    "        output = net(data)\n",
    "        predictions = nd.argmax(output, axis=1)\n",
    "        numerator += nd.sum(predictions == label)\n",
    "        denominator += data.shape[0]\n",
    "    return (numerator / denominator).asscalar()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Complete loop\n",
    "\n",
    "The complete training loop is given below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:25:57.242959Z",
     "start_time": "2018-07-24T03:05:40.446256Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "epochs = config['epochs']\n",
    "learning_rate = config['learning_rate']\n",
    "smoothing_constant = .01\n",
    "train_acc = []\n",
    "test_acc = []\n",
    "\n",
    "for e in range(epochs):\n",
    "    for i, (data, label) in enumerate(train_data): \n",
    "        data = data.as_in_context(ctx).reshape((-1, 784))\n",
    "        label = label.as_in_context(ctx)\n",
    "        label_one_hot = nd.one_hot(label, 10)\n",
    "        \n",
    "        with autograd.record():\n",
    "            # generate sample\n",
    "            layer_params, sigmas = generate_weight_sample(shapes, raw_mus, raw_rhos)\n",
    "            \n",
    "            # overwrite network parameters with sampled parameters \n",
    "            for sample, param in zip(layer_params, net.collect_params().values()):\n",
    "                param._data[0] = sample\n",
    "                \n",
    "            # forward-propagate the batch \n",
    "            output = net(data)\n",
    "            \n",
    "            # calculate the loss\n",
    "            loss = bbb_loss(output, label_one_hot, layer_params, raw_mus, sigmas)\n",
    "            \n",
    "            # backpropagate for gradient calculation\n",
    "            loss.backward()\n",
    "        \n",
    "        trainer.step(data.shape[0])\n",
    "        \n",
    "        # calculate moving loss for monitoring convergence\n",
    "        curr_loss = nd.mean(loss).asscalar()\n",
    "        moving_loss = (curr_loss if ((i == 0) and (e == 0)) \n",
    "                       else (1 - smoothing_constant) * moving_loss + (smoothing_constant) * curr_loss)\n",
    "    \n",
    "    test_accuracy = evaluate_accuracy(test_data, net, raw_mus)\n",
    "    train_accuracy = evaluate_accuracy(train_data, net, raw_mus)\n",
    "    train_acc.append(np.asscalar(train_accuracy))\n",
    "    test_acc.append(np.asscalar(test_accuracy))\n",
    "    print(\"Epoch %s. Loss: %s, Train_acc %s, Test_acc %s\" %\n",
    "          (e, moving_loss, train_accuracy, test_accuracy))\n",
    "    \n",
    "plt.plot(train_acc)\n",
    "plt.plot(test_acc)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For demonstration purposes, we can now take a look at one particular weight by plotting its distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:25:57.442889Z",
     "start_time": "2018-07-24T03:25:57.246459Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def gaussian(x, mu, sigma):\n",
    "    scaling = 1.0 / nd.sqrt(2.0 * np.pi * (sigma ** 2))\n",
    "    bell = nd.exp(- (x - mu) ** 2 / (2.0 * sigma ** 2))\n",
    "\n",
    "    return scaling * bell\n",
    "\n",
    "def show_weight_dist(mean, variance):\n",
    "    sigma = nd.sqrt(variance)\n",
    "    x = np.linspace(mean.asscalar() - 4*sigma.asscalar(), mean.asscalar() + 4*sigma.asscalar(), 100)\n",
    "    plt.plot(x, gaussian(nd.array(x, ctx=ctx), mean, sigma).asnumpy())\n",
    "    plt.show()\n",
    "    \n",
    "mu = raw_mus[0][0][0]\n",
    "var = softplus(raw_rhos[0][0][0]) ** 2\n",
    "\n",
    "show_weight_dist(mu, var)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Weight pruning\n",
    "\n",
    "To measure the degree of redundancy present in the trained network and to reduce the model's parameter count, we now want to examine the effect of setting some of the weights to $0$ and evaluate the test accuracy afterwards. We can achieve this by ordering the weights according to their signal-to-noise-ratio, $\\frac{|\\mu_i|}{\\sigma_i}$, and setting a certain percentage of the weights with the lowest ratios to $0$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can calculate the signal-to-noise-ratio as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:25:57.451276Z",
     "start_time": "2018-07-24T03:25:57.446032Z"
    }
   },
   "outputs": [],
   "source": [
    "def signal_to_noise_ratio(mus, sigmas):\n",
    "    sign_to_noise = []\n",
    "    for j in range(len(mus)):\n",
    "        sign_to_noise.extend([nd.abs(mus[j]) / sigmas[j]])\n",
    "    return sign_to_noise"
   ]
  },
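  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example with made-up numbers, consider two weights that share the same posterior standard deviation $\\sigma = 0.05$ but have the means $\\mu_1 = 0.5$ and $\\mu_2 = -0.02$:\n",
    "\n",
    "$$\\frac{|\\mu_1|}{\\sigma} = \\frac{0.5}{0.05} = 10, \\qquad \\frac{|\\mu_2|}{\\sigma} = \\frac{0.02}{0.05} = 0.4$$\n",
    "\n",
    "The second weight is barely distinguishable from $0$ under its own posterior, so it would be pruned before the first."
   ]
  },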
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We further introduce a few helper methods which turn our list of weights into a single vector containing all weights. This will make our subsequent actions easier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:25:57.466128Z",
     "start_time": "2018-07-24T03:25:57.459218Z"
    }
   },
   "outputs": [],
   "source": [
    "def vectorize_matrices_in_vector(vec):\n",
    "    for i in range(0, (num_layers + 1) * 2, 2):\n",
    "        if i == 0:\n",
    "            vec[i] = nd.reshape(vec[i], num_inputs * num_hidden)\n",
    "        elif i == num_layers * 2:\n",
    "            vec[i] = nd.reshape(vec[i], num_hidden * num_outputs)\n",
    "        else:\n",
    "            vec[i] = nd.reshape(vec[i], num_hidden * num_hidden)\n",
    "            \n",
    "    return vec\n",
    "\n",
    "def concact_vectors_in_vector(vec):\n",
    "    concat_vec = vec[0]\n",
    "    for i in range(1, len(vec)):\n",
    "        concat_vec = nd.concat(concat_vec, vec[i], dim=0)\n",
    "    \n",
    "    return concat_vec\n",
    "\n",
    "def transform_vector_structure(vec):\n",
    "    vec = vectorize_matrices_in_vector(vec)\n",
    "    vec = concact_vectors_in_vector(vec)\n",
    "    \n",
    "    return vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition, we also have a helper method which transforms the pruned weight vector back to the original layered structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:25:57.479438Z",
     "start_time": "2018-07-24T03:25:57.470930Z"
    }
   },
   "outputs": [],
   "source": [
    "from functools import reduce\n",
    "import operator\n",
    "\n",
    "def prod(iterable):\n",
    "    return reduce(operator.mul, iterable, 1)\n",
    "\n",
    "def restore_weight_structure(vec):\n",
    "    pruned_weights = []\n",
    "    \n",
    "    index = 0\n",
    "    \n",
    "    for shape in shapes:\n",
    "        incr = prod(shape)\n",
    "        pruned_weights.extend([nd.reshape(vec[index : index + incr], shape)])\n",
    "        index += incr\n",
    "    \n",
    "    return pruned_weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The actual pruning of the vector happens in the following function. Note that this function accepts an ordered list of percentages to evaluate the performance at different pruning rates. In this setting, pruning at each iteration means extracting the index of the lowest signal-to-noise-ratio weight and setting the weight at this index to $0$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:25:57.490140Z",
     "start_time": "2018-07-24T03:25:57.483328Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def prune_weights(sign_to_noise_vec, prediction_vector, percentages):\n",
    "    pruning_indices = nd.argsort(sign_to_noise_vec, axis=0)\n",
    "    \n",
    "    for percentage in percentages:\n",
    "        prediction_vector = mus_copy_vec.copy()\n",
    "        pruning_indices_percent = pruning_indices[0:int(len(pruning_indices)*percentage)]\n",
    "        for pr_ind in pruning_indices_percent:\n",
    "            prediction_vector[int(pr_ind.asscalar())] = 0\n",
    "        pruned_weights = restore_weight_structure(prediction_vector)\n",
    "        test_accuracy = evaluate_accuracy(test_data, net, pruned_weights)\n",
    "        print(\"%s --> %s\" % (percentage, test_accuracy))"
   ]
  },
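  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feeling for the numbers involved, we can count the entries of the flattened vector for the configuration used in this notebook (two hidden layers with 400 units each, 784 inputs and 10 outputs). Note that the vector contains the biases as well as the weights, so both are candidates for pruning:\n",
    "\n",
    "$$\\underbrace{784 \\cdot 400 + 400}_{\\text{layer 1}} + \\underbrace{400 \\cdot 400 + 400}_{\\text{layer 2}} + \\underbrace{400 \\cdot 10 + 10}_{\\text{output}} = 478410$$\n",
    "\n",
    "At a pruning rate of $0.95$, the function zeroes $\\lfloor 0.95 \\cdot 478410 \\rfloor = 454489$ of these entries and leaves $23921$ untouched."
   ]
  },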
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting the above function together:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-07-24T03:29:45.790199Z",
     "start_time": "2018-07-24T03:25:57.494190Z"
    }
   },
   "outputs": [],
   "source": [
    "sign_to_noise = signal_to_noise_ratio(raw_mus, sigmas)\n",
    "sign_to_noise_vec = transform_vector_structure(sign_to_noise)\n",
    "\n",
    "mus_copy = raw_mus.copy()\n",
    "mus_copy_vec = transform_vector_structure(mus_copy)\n",
    "\n",
    "prune_weights(sign_to_noise_vec, mus_copy_vec, [0.1, 0.25, 0.5, 0.75, 0.95, 0.98, 1.0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Depending on the number of units used in the original network, the highest achievable pruning percentages (without significantly reducing the predictive performance) can vary. The paper, for example, reports almost no change in the test accuracy when pruning 95% of the weights in a 1200 unit Bayesian neural network, which creates a significantly sparser network, leading to faster predictions and reduced memory requirements."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "We have taken a look at an efficient Bayesian treatment for neural networks using variational inference via the \"Bayes by Backprop\" algorithm (introduced by the \"[Weight Uncertainity in Neural Networks](https://arxiv.org/abs/1505.05424)\" paper). We have implemented a stochastic version of the variational lower bound and optimized it in order to find an approximation to the posterior distribution over the weights of a MLP network on the MNIST data set. As a result, we achieve regularization on the network's parameters and can quantify our uncertainty about the weights accurately. Finally, we saw that it is possible to significantly reduce the number of weights in the neural network after training while still keeping a high accuracy on the test set.\n",
    "\n",
    "We also note that, given this model implementation, we were able to reproduce the paper's results on the MNIST data set, achieving a comparable test accuracy for all documented instances of the MNIST classification problem."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For whinges or inquiries, [open an issue on  GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
